from typing import Optional
import os
import yaml

# TORCHAO_ROOT must point at the repository root, i.e. the directory that
# contains the `torchao` package; every input and output path below is
# resolved relative to it.
torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
if torchao_root is None:
    # Raise explicitly rather than assert: assertions are stripped under
    # `python -O`, which would turn this into a confusing TypeError from
    # os.path.join(None, ...) below.
    raise RuntimeError("TORCHAO_ROOT is not set")

# Root of the MPS kernel sources (metal.yaml, metal/, src/).
MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")

# Path to the YAML manifest; each op entry may name its shader source under a
# "file" key (a path relative to the metal/ directory).
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")

with open(METAL_YAML, "r", encoding="utf-8") as yamlf:
    # safe_load never constructs arbitrary Python objects from the manifest.
    # An empty manifest loads as None; treat it as "no ops" instead of
    # crashing when iterating it.
    metal_config = yaml.safe_load(yamlf) or []

# Only mapping entries can carry a "file" key; for a plain string entry,
# `"file" in op` would be a substring test followed by a failing op["file"].
# Several ops may share one shader, so deduplicate, then sort so the generated
# header is byte-for-byte deterministic across runs.
metal_files = sorted(
    {op["file"] for op in metal_config if isinstance(op, dict) and "file" in op}
)

# Generated C++ header that receives every shader source, concatenated.
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

# Directory holding the individual .metal sources named in metal.yaml.
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Text emitted before the shader sources. Under ATEN it pulls in the ATen MPS
# namespace, otherwise it includes OperationUtils.h. It then opens a C++ raw
# string literal delimited by METAL_LOWBIT, into which every .metal file is
# pasted verbatim. The trailing "" makes the join end with a newline.
prefix = "\n".join(
    (
        "/**",
        " * This file is generated by gen_metal_shader_lib.py",
        " */",
        "",
        "#ifdef ATEN",
        "using namespace at::native::mps;",
        "#else",
        "#include <torchao/experimental/kernels/mps/src/OperationUtils.h>",
        "#endif",
        "",
        'static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(',
        "",
    )
)

# Closes the raw string literal and the MetalShaderLibrary constructor call.
suffix = '\n)METAL_LOWBIT");\n'

# Banner written before each shader; "{}" is filled with the shader file name.
comment = "\n/**\n * Contents of {}\n */\n\n"

# Assemble the whole header in memory first. Opening OUTPUT_FILE with "w"
# truncates it, so reading the shaders up front ensures a missing or invalid
# source aborts the run without leaving a half-written header behind.
parts = [prefix]
for file in metal_files:
    with open(os.path.join(METAL_DIR, file), "r", encoding="utf-8") as f:
        content = f.read()
    # Sources are embedded in the C++ raw string R"METAL_LOWBIT( ... )METAL_LOWBIT";
    # this sequence inside a shader would end the literal early and silently
    # produce broken C++, so fail loudly instead.
    if ')METAL_LOWBIT"' in content:
        raise ValueError(
            f"{file} contains the raw-string terminator ')METAL_LOWBIT\"' "
            "and cannot be embedded in the generated header"
        )
    parts.append(comment.format(file))
    parts.append(content)
    parts.append("\n\n")
parts.append(suffix)

with open(OUTPUT_FILE, "w", encoding="utf-8") as outf:
    outf.write("".join(parts))
